{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SimpleNet - Development Notebook\n",
    "\n",
    "Notebook to try different parameters for SimpleNet.\n",
    "\n",
    "## Base architecture\n",
    "\n",
    "<img src=\"../../../doc/assets/simplenet_arch.png\"/>\n",
    "\n",
    "Note: Diagram generated with [NN-SVG by Alex Nail](http://alexlenail.me/NN-SVG/)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "WORKSPACE_BASE_PATH=\"/tf/notebooks/\" # Parent directory containing src, checkpoints, models, etc.\n",
    "CODE_BASE_PATH=\"/tf/notebooks/src/\" # Path were components are stored.\n",
    "DATA_BASE_PATH=\"/tf/notebooks/data/\" # Directory with data in case it is not inside WORKSPACE BASE path.\n",
    "sys.path.append(CODE_BASE_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# -*- coding: utf-8 -*-\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras.models import Sequential\n",
    "from tensorflow.keras.layers import Conv2D\n",
    "from tensorflow.keras.layers import BatchNormalization\n",
    "from tensorflow.keras.layers import MaxPooling2D\n",
    "from tensorflow.keras.layers import Flatten\n",
    "from tensorflow.keras.layers import Dropout\n",
    "from tensorflow.keras.layers import Dense\n",
    "from tensorflow.keras.layers import ELU\n",
    "from tensorflow.keras.regularizers import l2\n",
    "from tensorflow.keras.optimizers import SGD\n",
    "from tensorflow.keras.callbacks import ModelCheckpoint\n",
    "from tensorflow.keras.callbacks import CSVLogger\n",
    "from tensorflow.keras.preprocessing.image import ImageDataGenerator\n",
    "\n",
    "from tensorflow.keras.utils import plot_model\n",
    "from tensorflow.keras.callbacks import TensorBoard\n",
    "\n",
    "from IPython.display import SVG\n",
    "from tensorflow.keras.utils import plot_model\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model definition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_simplenet(input_shape=(64,64,3),n_output_classes=2):\n",
    "    \"\"\"\n",
    "    \"\"\"\n",
    "    model = Sequential()\n",
    "    \n",
    "    # Convolution + Pooling #1\n",
    "    model.add(Conv2D( 32, (3, 3), input_shape=input_shape,\n",
    "                          activation = 'relu' ))        \n",
    "    model.add( MaxPooling2D(pool_size = (2,2)))\n",
    "    \n",
    "    # Convolution + Pooling #2\n",
    "    model.add(Conv2D( 32, (3, 3), activation = 'relu' ))        \n",
    "    model.add( MaxPooling2D(pool_size = (2,2)))\n",
    "    \n",
    "    # Flattening\n",
    "    model.add( Flatten() )\n",
    "    \n",
    "    # FC #1\n",
    "    model.add( Dense( units = 128, activation = 'relu' ) )\n",
    "    \n",
    "    # Output Layer\n",
    "    model.add( Dense( units = n_output_classes, activation = 'softmax' ) )   \n",
    "    \n",
    "    # Compile\n",
    "    model.compile( \n",
    "        optimizer = 'adam', loss = 'categorical_crossentropy',\n",
    "        metrics = ['accuracy'] )\n",
    "    return model\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = build_simplenet()\n",
    "model.summary()\n",
    "plot_model(model, show_shapes=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_simplenet( model,\n",
    "               target_size,\n",
    "               dataset_path,\n",
    "               training_path_prefix,\n",
    "               test_path_prefix,                        \n",
    "               history_file_path,\n",
    "               history_filename,\n",
    "               checkpoint_path,\n",
    "               checkpoint_prefix,\n",
    "               number_of_epochs,\n",
    "               tensorboard_log_path\n",
    "            ):\n",
    "    \"\"\"\n",
    "        see: https://keras.io/preprocessing/image/\n",
    "    \"\"\"\n",
    "    train_datagen = ImageDataGenerator( rescale=1./255,\n",
    "                                        shear_range=0.2,\n",
    "                                        zoom_range=0.2,\n",
    "                                        horizontal_flip=True )\n",
    "        \n",
    "    test_datagen = ImageDataGenerator(rescale=1./255)\n",
    "    training_set_generator = train_datagen.flow_from_directory(\n",
    "            dataset_path+training_path_prefix,\n",
    "            target_size,\n",
    "            batch_size=32,\n",
    "            class_mode='categorical',\n",
    "            shuffle=True,\n",
    "            seed=42\n",
    "        )\n",
    "    test_set_generator = test_datagen.flow_from_directory(\n",
    "            dataset_path+test_path_prefix,\n",
    "            target_size,\n",
    "            batch_size=32,\n",
    "            class_mode='categorical',\n",
    "            shuffle=True,\n",
    "            seed=42\n",
    "        )\n",
    "        \n",
    "    step_size_train=training_set_generator.n//training_set_generator.batch_size\n",
    "    step_size_validation=test_set_generator.n//test_set_generator.batch_size\n",
    "\n",
    "    check_pointer = ModelCheckpoint(\n",
    "            checkpoint_path + '%s_weights.{epoch:02d}-{val_loss:.2f}.hdf5' % checkpoint_prefix, \n",
    "            monitor='val_loss', \n",
    "            mode='auto', \n",
    "            save_best_only=True\n",
    "    )\n",
    "    \n",
    "    tensorboard_logger = TensorBoard( \n",
    "        log_dir=tensorboard_log_path, histogram_freq=0,  \n",
    "          write_graph=True, write_images=True\n",
    "    )\n",
    "    tensorboard_logger.set_model(model)\n",
    "\n",
    "    csv_logger = CSVLogger(filename=history_file_path+history_filename)\n",
    "    history = model.fit_generator(\n",
    "            training_set_generator,\n",
    "            steps_per_epoch=step_size_train,\n",
    "            epochs=number_of_epochs,\n",
    "            validation_data=test_set_generator,\n",
    "            validation_steps=step_size_validation,\n",
    "            callbacks=[check_pointer, csv_logger,tensorboard_logger] \n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "MODEL_NAME=\"simplenet_cracks8020\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_cracknet(  model,\n",
    "            target_size=(120,120),\n",
    "            dataset_path=DATA_BASE_PATH+\"/datasets/cracks_splitted8020/\",\n",
    "            training_path_prefix=\"train_set\",\n",
    "            test_path_prefix=\"test_set\",\n",
    "            history_file_path=WORKSPACE_BASE_PATH+\"/training_logs/\",\n",
    "            history_filename=MODEL_NAME+\".csv\",\n",
    "            checkpoint_path=WORKSPACE_BASE_PATH+\"/model-checkpoints/\",\n",
    "            checkpoint_prefix=MODEL_NAME,\n",
    "            number_of_epochs=30, \n",
    "            tensorboard_log_path=WORKSPACE_BASE_PATH+\"/tensorboard_logs/\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training Report"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import matplotlib\n",
    "from utils.plotutils import plot_learning_curves_from_history_file\n",
    "%matplotlib inline\n",
    "matplotlib.rcParams['figure.figsize'] = (10, 6)\n",
    "fig = plot_learning_curves_from_history_file(WORKSPACE_BASE_PATH+\"/training_logs/\"+MODEL_NAME+\".csv\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export model to SavedModelFormat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import glob\n",
    "import os\n",
    "\n",
    "list_of_files = glob.glob(WORKSPACE_BASE_PATH+'/model-checkpoints/*.hdf5') \n",
    "CHECKPOINT_FILE = max(list_of_files, key=os.path.getctime) # last checkpoint\n",
    "VERSION=1\n",
    "print(CHECKPOINT_FILE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.export_tools import convert_from_keras_to_savedmodel\n",
    "\n",
    "convert_from_keras_to_savedmodel(\n",
    "    input_filename=CHECKPOINT_FILE,\n",
    "    export_path=WORKSPACE_BASE_PATH+'/models/'+MODEL_NAME+\"/\"+str(VERSION)\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
